# Coder: Wenxin Xu
# Github: https://github.com/wenxinxu/resnet_in_tensorflow
# ==============================================================================
'''
This is the resnet structure
'''
import numpy as np
from hyper_parameters import *
from matplotlib import pyplot as plt
import time


BN_EPSILON = 0.001
def wtanh(x, Th = 1e-1):
    '''
    Scaled tanh: a soft clip whose output saturates at +/-Th (for positive Th).
    :param x: value or tensor accepted by tf.tanh
    :param Th: saturation threshold; the input is divided by it and the output multiplied by it
    :return: Th * tanh(x / Th)
    '''
    rescaled = x / Th
    return Th * tf.tanh(rescaled)

def Th_decay(t,mode='default'):
    '''
    Threshold schedule selector.
    :param t: step counter forwarded to the schedule
    :param mode: schedule name; only 'default' is implemented
    :return: inv_decay(t) for mode 'default', otherwise None
    '''
    if mode != 'default':
        # No other schedule exists; unknown modes yield None.
        return None
    return inv_decay(t)
    
def inv_decay(t, A=50*20, b=0.1, c=0.01):
    '''
    Inverse-time decay that starts at b (t = 0) and approaches c as t grows.
    At t == A the value is exactly halfway, (b + c) / 2.
    :param t: step counter
    :param A: time constant of the decay
    :param b: value at t = 0
    :param c: asymptotic value
    :return: A * (b - c) / (t + A) + c
    '''
    span = A * (b - c)
    return span / (t + A) + c


# class sr_opt(tf.train.GradientDescentOptimizer):
#     def __init__(self,var_list=None):
        
# class sr_opt(tf.train.GradientDescentOptimizer):
class sr_opt(tf.train.MomentumOptimizer):
    '''
    Momentum optimizer (learning rate 1.0, momentum 0.9) that feeds the base
    optimizer a transformed gradient instead of the raw one.

    The transformation (see SR_1grad) soft-clips the gradient with wtanh,
    runs it through Adam-style first/second moment formulas, and scales the
    result by 1e-2. Because the step size is already part of the transformed
    gradient, the base optimizer's learning rate is 1.0.

    A history of the last N_pre raw gradients per variable is also maintained
    (self.past_gradients); the current update rule does not read it.
    '''

    def __init__(self,sr_mode='RNNprop',var_list=None):
        '''
        :param sr_mode: mode label ('RNNprop' or 'sgd'). It is stored on the instance, but
        SR_1grad currently always applies the same tanh-based rule regardless of its value.
        :param var_list: variables to optimize; defaults to the TRAINABLE_VARIABLES collection
        at construction time.
        '''
        super(sr_opt,self).__init__(1., momentum=0.9)

        self.sr_mode = sr_mode
        self.N_pre = 20
        self.global_step = None  # set by minimize()
        if var_list is None:
            self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        else:
            self.var_list = var_list
        self.shapes = [v.shape.as_list() for v in self.var_list]

        # past_gradients[i_layer][i_pre] is a variable shaped like layer i's weights.
        # NOTE(review): tf.Variable defaults to trainable=True, so these buffers are added to
        # TRAINABLE_VARIABLES after var_list was captured above - confirm that is acceptable.
        self.past_gradients = []
        for sp in self.shapes:
            holder_layeri = [tf.Variable(np.zeros(sp,dtype=np.float32))
                             for _ in range(self.N_pre)]
            self.past_gradients.append(holder_layeri)

        # Moment state is needed by SR_1grad for every sr_mode, so it is created
        # unconditionally (previously sr_mode='sgd' failed with AttributeError).
        # NOTE(review): mt/vt are plain zero tensors, not variables, so nothing in this class
        # persists them between session.run calls - confirm that is intended.
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.mt = []
        self.vt = []
        for sp in self.shapes:
            self.mt.append(tf.zeros(sp,dtype=tf.float32))
            self.vt.append(tf.zeros(sp,dtype=tf.float32))

        self.past_grads = []
        return

    def minimize(self,loss,**w):
        '''
        Build the training ops for `loss`.
        :param loss: scalar tensor to minimize
        :param w: keyword arguments; only 'global_step' is read, and it is required because
        the update rule depends on the step counter.
        :return: [list of history-buffer assign ops, apply_gradients op]; run both each step.
        :raises ValueError: if 'global_step' is missing or None.
        '''
        self.global_step = w.get('global_step')
        if self.global_step is None:
            raise ValueError("sr_opt.minimize requires a 'global_step' keyword argument")

        grads_vars_list = self.compute_gradients(loss,var_list=self.var_list)
        grads_vars_SR = self.SR_proc(grads_vars_list)
        apply_grad_op = self.apply_gradients(grads_vars_SR, self.global_step)

        # Shift each layer's history by one slot (index N_pre-1 down to 1), then store the
        # current raw gradient in slot 0.
        # NOTE(review): no control dependencies order these assigns relative to each other,
        # and entry il of grads_vars_list is assumed to belong to var_list[il].
        update_buff_op = []
        for il, grad_layeri in enumerate(self.past_gradients):
            for i in range(self.N_pre-1, 0, -1):
                update_buff_op.append(tf.assign(grad_layeri[i], grad_layeri[i-1]))
            true_grad_layer_i = grads_vars_list[il][0]
            update_buff_op.append(tf.assign(grad_layeri[0], true_grad_layer_i))

        return [update_buff_op, apply_grad_op]

    def SR_proc(self, grads_vars_list):
        '''
        Apply SR_1grad to every gradient.
        :param grads_vars_list: list of (gradient, variable) pairs from compute_gradients
        :return: list of (transformed gradient, variable) pairs in the same order
        '''
        return [(self.SR_1grad(il, grad), var)
                for il, (grad, var) in enumerate(grads_vars_list)]

    def SR_1grad(self, il, grad):
        '''
        Transform one gradient tensor into the update handed to the base optimizer.
        :param il: index of the variable/layer (selects the mt/vt entries)
        :param grad: raw gradient tensor for that variable
        :return: 1e-2 * wtanh(2*wtanh(o1) + 0.5*wtanh(o2)), every wtanh using threshold
        Th_decay(t), where o1 and o2 are the bias-corrected first moment and the raw gradient,
        each divided by sqrt(bias-corrected second moment) + 1e-8.
        '''
        # Offset of 3 presumably keeps the bias-correction denominators
        # (1 - beta**t) away from zero at step 0 - TODO confirm.
        t = tf.cast(self.global_step+3, tf.float32)
        th = Th_decay(t)

        # The moment formulas see the soft-clipped gradient, while o2 below uses the raw one.
        self.mt[il], self.vt[il] = rms_momentum(wtanh(grad,th), self.mt[il], self.vt[il],
                                                self.beta1, self.beta2)
        mthat = self.mt[il]/(1-tf.math.pow(self.beta1,t))
        sqvthat = tf.sqrt( self.vt[il]/(1-tf.math.pow(self.beta2,t)) )
        o1 = mthat/(sqvthat + 1e-8)
        o2 = grad/(sqvthat + 1e-8)

        update = 1e-2*wtanh(2*wtanh(o1,th)+0.5*wtanh(o2,th),th)
        return update
        
def rms_momentum(grad, m, v, beta_1=0.9, beta_2=0.99):
    '''
    One exponential-moving-average step for the first and second gradient moments.
    :param grad: current gradient
    :param m: previous first-moment estimate
    :param v: previous second-moment estimate
    :param beta_1: decay rate of the first moment
    :param beta_2: decay rate of the second moment
    :return: (updated m, updated v)
    '''
    first_moment = beta_1 * m + (1. - beta_1) * grad
    second_moment = beta_2 * v + (1. - beta_2) * tf.square(grad)
    return first_moment, second_moment

















def activation_summary(x):
    '''
    Hook for per-activation summaries. The summary calls are disabled, so this is a no-op.
    :param x: A Tensor (currently unused)
    :return: None
    '''
    # Disabled summaries, kept for reference:
    # tensor_name = x.op.name
    # tf.summary.histogram(tensor_name + '/activations', x)
    # tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))
    return None


def create_variables(name, shape, initializer=tf.contrib.layers.xavier_initializer(), is_fc_layer=False):
    '''
    Create (or fetch, under a reusing scope) a variable with L2 regularization attached.
    :param name: A string. The name of the new variable
    :param shape: A list of dimensions
    :param initializer: Initializer for the variable; Xavier by default.
    :param is_fc_layer: Accepted but currently ignored; conv and fc variables share the same
    weight decay.
    :return: The created variable
    '''

    ## TODO: to allow different weight decay to fully connected layer and conv layer
    l2_penalty = tf.contrib.layers.l2_regularizer(scale=FLAGS.weight_decay)
    return tf.get_variable(name, shape=shape, initializer=initializer,
                           regularizer=l2_penalty)


def output_layer(input_layer, num_labels):
    '''
    Final fully connected layer.
    :param input_layer: 2D tensor
    :param num_labels: int. Number of output labels (10 for cifar10 and 100 for cifar100)
    :return: output layer Y = XW + B, with no activation applied
    '''
    feature_dim = input_layer.get_shape().as_list()[-1]
    weights = create_variables(name='fc_weights', shape=[feature_dim, num_labels],
                               is_fc_layer=True,
                               initializer=tf.uniform_unit_scaling_initializer(factor=1.0))
    bias = create_variables(name='fc_bias', shape=[num_labels],
                            initializer=tf.zeros_initializer())

    return tf.matmul(input_layer, weights) + bias


def batch_normalization_layer(input_layer, dimension):
    '''
    Batch-normalize a tensor using the statistics of the current batch.
    The moments are taken over axes 0, 1 and 2 on every call; no moving averages are kept.
    :param input_layer: 4D tensor
    :param dimension: input_layer.get_shape().as_list()[-1]. The depth of the 4D tensor
    :return: the 4D tensor after being normalized
    '''
    batch_mean, batch_var = tf.nn.moments(input_layer, axes=[0, 1, 2])
    # Learnable shift (beta, starts at 0) and scale (gamma, starts at 1).
    shift = tf.get_variable('beta', dimension, tf.float32,
                            initializer=tf.constant_initializer(0.0, tf.float32))
    scale = tf.get_variable('gamma', dimension, tf.float32,
                            initializer=tf.constant_initializer(1.0, tf.float32))
    return tf.nn.batch_normalization(input_layer, batch_mean, batch_var, shift, scale,
                                     BN_EPSILON)


def conv_bn_relu_layer(input_layer, filter_shape, stride):
    '''
    Apply convolution, batch normalization and relu, in that order.
    :param input_layer: 4D tensor
    :param filter_shape: list. [filter_height, filter_width, filter_depth, filter_number]
    :param stride: stride size for conv
    :return: 4D tensor. Y = Relu(batch_normalize(conv(X)))
    '''
    kernel = create_variables(name='conv', shape=filter_shape)
    convolved = tf.nn.conv2d(input_layer, kernel, strides=[1, stride, stride, 1],
                             padding='SAME')
    # The normalization depth is the number of filters, i.e. the last filter_shape entry.
    normalized = batch_normalization_layer(convolved, filter_shape[-1])
    return tf.nn.relu(normalized)


def bn_relu_conv_layer(input_layer, filter_shape, stride):
    '''
    Apply batch normalization, relu and convolution, in that order (pre-activation layout).
    :param input_layer: 4D tensor
    :param filter_shape: list. [filter_height, filter_width, filter_depth, filter_number]
    :param stride: stride size for conv
    :return: 4D tensor. Y = conv(Relu(batch_normalize(X)))
    '''
    depth_in = input_layer.get_shape().as_list()[-1]
    activated = tf.nn.relu(batch_normalization_layer(input_layer, depth_in))

    kernel = create_variables(name='conv', shape=filter_shape)
    return tf.nn.conv2d(activated, kernel, strides=[1, stride, stride, 1], padding='SAME')



def residual_block(input_layer, output_channel, first_block=False):
    '''
    Build one residual block: two 3x3 convolutions plus a shortcut connection.
    :param input_layer: 4D tensor
    :param output_channel: int. Must equal the input depth (same-size block) or twice the
    input depth (down-sampling block).
    :param first_block: if this is the first residual block of the whole network
    :return: 4D tensor whose last dimension is output_channel.
    :raises ValueError: if output_channel is neither the input depth nor twice it.
    '''
    input_channel = input_layer.get_shape().as_list()[-1]

    # Doubling the depth goes together with halving the spatial size (stride 2).
    if input_channel * 2 == output_channel:
        increase_dim, stride = True, 2
    elif input_channel == output_channel:
        increase_dim, stride = False, 1
    else:
        raise ValueError('Output and input channel does not match in residual blocks!!!')

    with tf.variable_scope('conv1_in_block'):
        if not first_block:
            conv1 = bn_relu_conv_layer(input_layer, [3, 3, input_channel, output_channel], stride)
        else:
            # The very first block skips batch norm / relu and always convolves with stride 1.
            kernel = create_variables(name='conv', shape=[3, 3, input_channel, output_channel])
            conv1 = tf.nn.conv2d(input_layer, filter=kernel, strides=[1, 1, 1, 1], padding='SAME')

    with tf.variable_scope('conv2_in_block'):
        conv2 = bn_relu_conv_layer(conv1, [3, 3, output_channel, output_channel], 1)

    if increase_dim:
        # Shortcut for a down-sampling block: 2x2 average pool, then zero-pad the channel
        # axis by input_channel // 2 on each side.
        shortcut = tf.nn.avg_pool(input_layer, ksize=[1, 2, 2, 1],
                                  strides=[1, 2, 2, 1], padding='VALID')
        channel_pad = input_channel // 2
        shortcut = tf.pad(shortcut, [[0, 0], [0, 0], [0, 0], [channel_pad, channel_pad]])
    else:
        shortcut = input_layer

    return conv2 + shortcut


def inference_small_conv(inputs, n, reuse):
    '''
    Small convolutional network: two conv blocks followed by one fully connected layer.
    :param inputs: 4D tensor with a statically known batch dimension
    :param n: unused; present so the signature matches inference_resnet
    :param reuse: passed to the 'small_conv' variable scope to share weights between graphs
    :return: 2D tensor with 10 columns; note that a relu is applied to the final layer
    '''
    batch_norm = True
    batch_size = int(inputs.shape[0])

    with tf.variable_scope('small_conv', reuse=reuse):

        def _relu_then_pool(x):
            # relu followed by 2x2 max pooling with stride 2
            activated = tf.nn.relu(x)
            return tf.nn.max_pool(activated,
                                  ksize=[1, 2, 2, 1],
                                  strides=[1, 2, 2, 1],
                                  padding="VALID")

        def _conv_block(x, strides, c_h, c_w, output_channels, padding, name):
            # conv + bias, optional batch norm (always in training mode), relu, max pool
            depth_in = int(x.get_shape()[-1])
            with tf.variable_scope(name):
                kernel = tf.get_variable('weights1',
                                         shape=[c_h, c_w, depth_in, output_channels],
                                         dtype=tf.float32,
                                         initializer=tf.random_normal_initializer(stddev=0.01)
                                         )
                offsets = tf.get_variable('biases1', [output_channels],
                                          initializer=tf.constant_initializer(0.0))
            # The ops below are deliberately built outside the per-layer variable scope.
            x = tf.nn.conv2d(x, kernel, [1, strides, strides, 1], padding)
            x = tf.nn.bias_add(x, offsets)
            if batch_norm:
                x = tf.layers.batch_normalization(x, training=True)
            return _relu_then_pool(x)

        features = _conv_block(inputs, 2, 3, 3, 16, "VALID", 'conv_layer1')
        features = _conv_block(features, 2, 5, 5, 32, "VALID", 'conv_layer2')
        flat = tf.reshape(features, [batch_size, -1])
        flat_dim = int(flat.get_shape()[1])
        weights = tf.get_variable("fc_weights",
                                  shape=[flat_dim, 10],
                                  dtype=tf.float32,
                                  initializer=tf.random_normal_initializer(stddev=0.01))
        bias = tf.get_variable("fc_bias",
                               shape=[10, ],
                               dtype=tf.float32,
                               initializer=tf.constant_initializer(0.0))

    projected = tf.matmul(flat, weights)
    return tf.nn.relu(tf.nn.bias_add(projected, bias))


def inference_resnet(input_tensor_batch, n, reuse):
    '''
    Build the ResNet. total layers = 1 + 2n + 2n + 2n + 1 = 6n + 2
    :param input_tensor_batch: 4D tensor (the shape asserts below expect the third stage to
    come out as 8 x 8 x 64)
    :param n: num_residual_blocks per stage
    :param reuse: To build train graph, reuse=False. To build validation graph and share weights
    with train graph, reuse=True
    :return: last layer in the network (10 outputs). Not softmax-ed
    '''
    with tf.variable_scope('conv0', reuse=reuse):
        net = conv_bn_relu_layer(input_tensor_batch, [3, 3, 3, 16], 1)
        activation_summary(net)

    # Stage 1: 16 channels; only its very first block is built with first_block=True.
    for i in range(n):
        with tf.variable_scope('conv1_%d' %i, reuse=reuse):
            if i == 0:
                net = residual_block(net, 16, first_block=True)
            else:
                net = residual_block(net, 16)
            activation_summary(net)

    # Stage 2: 32 channels.
    for i in range(n):
        with tf.variable_scope('conv2_%d' %i, reuse=reuse):
            net = residual_block(net, 32)
            activation_summary(net)

    # Stage 3: 64 channels, with a shape check after every block.
    for i in range(n):
        with tf.variable_scope('conv3_%d' %i, reuse=reuse):
            net = residual_block(net, 64)
        assert net.get_shape().as_list()[1:] == [8, 8, 64]

    with tf.variable_scope('fc', reuse=reuse):
        depth = net.get_shape().as_list()[-1]
        normalized = batch_normalization_layer(net, depth)
        activated = tf.nn.relu(normalized)
        # Global average pooling over the two spatial axes.
        global_pool = tf.reduce_mean(activated, [1, 2])

        assert global_pool.get_shape().as_list()[-1:] == [64]
        net = output_layer(global_pool, 10)

    return net


def test_graph(train_dir='logs'):
    '''
    Build the ResNet graph on a dummy batch and write it to `train_dir`, so the structure can
    be inspected in tensorboard.
    :param train_dir: directory that receives the graph event file
    '''
    input_tensor = tf.constant(np.ones([128, 32, 32, 3]), dtype=tf.float32)
    # Called for its side effect of adding the network to the default graph.
    # This file defines no `inference`; the ResNet builder is inference_resnet.
    inference_resnet(input_tensor, 2, reuse=False)
    init = tf.global_variables_initializer()
    # Close the session and the writer when done so the event file is flushed.
    with tf.Session() as sess:
        sess.run(init)
        summary_writer = tf.summary.FileWriter(train_dir, sess.graph)
        summary_writer.close()


def get_restorer(variables_list, filename):
  '''
  Build assign ops that load pickled values into the given variables.
  :param variables_list: variables to restore; each v.name must be a key of the pickled dict
  :param filename: path of a pickle file holding a {variable name: value} dict
  :return: (assigns, restore_holder), both dicts keyed by variable name. `assigns` maps to the
  assign op of each variable; `restore_holder` maps to a placeholder of matching shape and
  dtype.
  :raises KeyError: if a variable name is missing from the pickled dict.
  '''
  import pickle

  # NOTE: pickle.load can execute arbitrary code; only use with trusted checkpoint files.
  with open(filename, "rb") as f:
    data = pickle.load(f)

  restore_holder = {}
  assigns = {}
  for v in variables_list:
    name_ = v.name
    restore_holder[name_] = tf.placeholder(shape=v.get_shape(), dtype=v.dtype)
    # NOTE(review): the op assigns the pickled value directly and does not read the
    # placeholder created above - confirm whether assigning from the placeholder was intended.
    assigns[name_] = tf.assign(v, data[name_])

  return assigns, restore_holder
def restore(variables_list, restore_holder, assigns, sess, filename):
  '''
  Run the assign ops while feeding the pickled values into the matching placeholders.
  :param variables_list: variables to restore; each v.name must be a key of the pickled dict
  :param restore_holder: dict mapping variable name to the placeholder to feed
  :param assigns: fetches handed to sess.run (e.g. the dict returned by get_restorer)
  :param sess: session used to run the ops
  :param filename: path of a pickle file holding a {variable name: value} dict
  :return: None
  '''
  import pickle

  # NOTE: pickle.load can execute arbitrary code; only use with trusted checkpoint files.
  with open(filename, "rb") as f:
    data = pickle.load(f)

  feed = {restore_holder[v.name]: data[v.name] for v in variables_list}
  sess.run(assigns, feed_dict=feed)
  return

def save(variables_list, sess, filename):
  '''
  Evaluate the variables and pickle them as a {variable name: value} dict.
  :param variables_list: variables to save; each must provide .name and .eval(sess)
  :param sess: session passed to each variable's eval
  :param filename: destination path of the pickle file (overwritten if present)
  :return: the dict that was written
  '''
  import pickle

  to_save = {v.name: v.eval(sess) for v in variables_list}

  with open(filename, "wb") as f:
    pickle.dump(to_save, f)

  return to_save





class wsaver():
    '''
    Minimal pickle-based checkpoint helper; usage mirrors tf.train.Saver
    (construct with the variables, then call save / restore).
    '''
    def __init__(self,vars_list):
        '''
        :param vars_list: variables to save and restore
        '''
        self.vars_list = vars_list

    def restore(self, sess,ckpt_path):
        '''
        Load every variable in vars_list from a pickle written by save().
        :param sess: session used to run the assign ops
        :param ckpt_path: path of the pickle file
        :raises ValueError: if the checkpoint does not hold exactly the variables in vars_list.
        '''
        import pickle
        # NOTE: pickle.load can execute arbitrary code; only use with trusted checkpoint files.
        with open(ckpt_path, 'rb') as file:
            dic = pickle.load(file)
        if len(dic) != len(self.vars_list):
            raise ValueError('checkpoint holds {} variables, expected {}'.format(
                len(dic), len(self.vars_list)))
        # One assign op per variable, matched by name. The append used to sit outside the
        # loop, so only the last variable was ever restored.
        assigns = []
        for var in self.vars_list:
            if var.name not in dic:
                raise ValueError('variable {} not found in checkpoint'.format(var.name))
            assigns.append( tf.assign(var, dic[var.name]) )
        sess.run(assigns)
        return

    def save(self,sess,ckpt_path,global_step=None):
        '''
        Evaluate every variable in vars_list and pickle them as a {name: value} dict.
        :param sess: session used to evaluate the variables
        :param ckpt_path: destination path (overwritten if present)
        :param global_step: accepted for tf.train.Saver compatibility; ignored
        '''
        import pickle
        vs = {}
        for v in self.vars_list:
            vs[v.name]=sess.run(v)
        with open(ckpt_path, 'wb') as file:
            pickle.dump(vs,file)
        return
    
    
def lr_schedular(step):
    '''
    Learning rate for a given training step.
    :param step: training step
    :return: 0.001 for the active 'const1' mode; the 'invt1' mode (not selected) would return
    a stepwise inverse-time decay starting at 0.1.
    '''
    mode = ['invt1', 'const1', ][1]
    if mode=='invt1':
        initial_learning_rate = 0.1
        decay_step = 200
        decay_rate = 2.
        # Stepwise decay: the rate only changes every `decay_step` steps.
        staircase = True
        return InverseTimeDecay(step, initial_learning_rate, decay_step, decay_rate,
                                staircase=staircase)
    if mode=='const1':
        return 0.001





def InverseTimeDecay(step, initial_learning_rate, decay_step, decay_rate, staircase=False):
    '''
    Inverse-time decay: initial_learning_rate / (1 + decay_rate * step / decay_step).
    :param step: training step
    :param initial_learning_rate: learning rate at step 0
    :param decay_step: number of steps per decay period
    :param decay_rate: how strongly each period reduces the rate
    :param staircase: if True, step / decay_step is floored so the rate drops in discrete
    steps; if False the decay is continuous. (The two branches used to be swapped.)
    :return: the decayed learning rate
    '''
    periods = step / decay_step
    if staircase:
        periods = np.floor(periods)
    return initial_learning_rate / (1 + decay_rate * periods)


def plot_epoch(data, it, total_len, ttl=''):
  '''
  Accumulate one value per call in a temporary .npy file and periodically plot the series.

  At it == 0 the temporary file 'tmp_plot_epoch-<ttl>.npy' is (re)created; for
  0 < it <= total_len - 1 the value is appended to it. On every 20th call and on the last
  one (it == total_len - 1) the whole series is plotted and saved under ./wz_insp/.
  The temporary file is not deleted afterwards.
  :param data: value to record for this iteration
  :param it: 0-based iteration index; calls are expected to start at 0
  :param total_len: total number of iterations
  :param ttl: title of the plot, also used in the temporary and output file names
  :return: None
  '''
  import os

  tmp_path = 'tmp_plot_epoch-{}.npy'.format(ttl)
  last_it = int(total_len) - 1

  if it==0:
    if os.path.exists(tmp_path):
      print('pre-existing plot_epoch file found (not expected)! now removed it.')
      os.remove(tmp_path)
    np.save(tmp_path, [data,])
  elif it<=last_it:
    datas = np.load(tmp_path).tolist()
    datas.append(data)
    np.save(tmp_path, datas)

  if it==last_it or (it+1)%20==0:
    datas = np.load(tmp_path)
    plt.close('all')
    plt.figure()
    # Use the stored length for the x axis so it always matches the series, even when
    # `it` has run past total_len - 1 and nothing more was appended.
    plt.plot(np.arange(len(datas)), datas)
    plt.title(ttl)
    lt2 = time.strftime("%Y-%m-%d--%H_%M_%S", time.localtime())
    os.makedirs('wz_insp', exist_ok=True)

    plt.savefig('./wz_insp/{}___time_{}___{}'.format(ttl,it,lt2))
  return


